{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "10ae2b14",
   "metadata": {},
   "source": [
    "# Sentence Embeddings with Hugging Face Transformers, Sentence Transformers and Amazon SageMaker - Custom Inference for creating document embeddings with Hugging Face's Transformers\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5a1644f1",
   "metadata": {},
   "source": [
    "Welcome to this getting started guide. We will use the Hugging Face Inference DLCs and Amazon SageMaker Python SDK to create a [real-time inference endpoint](https://docs.aws.amazon.com/sagemaker/latest/dg/realtime-endpoints.html) running a Sentence Transformers for document embeddings. Currently, the [SageMaker Hugging Face Inference Toolkit](https://github.com/aws/sagemaker-huggingface-inference-toolkit) supports the [pipeline feature](https://huggingface.co/transformers/main_classes/pipelines.html) from Transformers for zero-code deployment. This means you can run compatible Hugging Face Transformer models without providing pre- & post-processing code. Therefore we only need to provide an environment variable `HF_TASK` and `HF_MODEL_ID` when creating our endpoint and the Inference Toolkit will take care of it. This is a great feature if you are working with existing [pipelines](https://huggingface.co/transformers/main_classes/pipelines.html).\n",
    "\n",
    "If you want to run other tasks, such as creating document embeddings, you can the pre- and post-processing code yourself, via an `inference.py` script. The Hugging Face Inference Toolkit allows the user to override the default methods of the `HuggingFaceHandlerService`.\n",
    "\n",
    "The custom module can override the following methods:\n",
    "\n",
    "- `model_fn(model_dir)` overrides the default method for loading a model. The return value `model` will be used in the`predict_fn` for predictions.\n",
    "  -  `model_dir` is the path to your unzipped `model.tar.gz`.\n",
    "- `input_fn(input_data, content_type)` overrides the default method for pre-processing. The return value `data` will be used in `predict_fn` for predictions. The inputs are:\n",
    "    - `input_data` is the raw body of your request.\n",
    "    - `content_type` is the content type from the request header.\n",
    "- `predict_fn(processed_data, model)` overrides the default method for predictions. The return value `predictions` will be used in `output_fn`.\n",
    "  - `model` returned value from `model_fn` methond\n",
    "  - `processed_data` returned value from `input_fn` method\n",
    "- `output_fn(prediction, accept)` overrides the default method for post-processing. The return value `result` will be the response to your request (e.g.`JSON`). The inputs are:\n",
    "    - `predictions` is the result from `predict_fn`.\n",
    "    - `accept` is the return accept type from the HTTP Request, e.g. `application/json`.\n",
    "\n",
    "In this example are we going to use Sentence Transformers to create sentence embeddings using a mean pooling layer on the raw representation.\n",
    "\n",
    "*NOTE: You can run this demo in Sagemaker Studio, your local machine, or Sagemaker Notebook Instances*"
   ]
  },
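  {
   "cell_type": "markdown",
   "id": "7c1d9e40",
   "metadata": {},
   "source": [
    "A minimal sketch of the two I/O hooks described above, `input_fn` and `output_fn`, assuming JSON-only requests. The JSON-only restriction and the error handling are assumptions made for illustration and do not reproduce the toolkit's default handlers; the `inference.py` used later in this notebook does not override these two hooks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f8b6a13",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "\n",
    "# Hypothetical sketch: only application/json is handled here (an assumption).\n",
    "def input_fn(input_data, content_type):\n",
    "    # Decode the raw request body into a Python object for predict_fn.\n",
    "    if content_type != \"application/json\":\n",
    "        raise ValueError(f\"Unsupported content type: {content_type}\")\n",
    "    if isinstance(input_data, (bytes, bytearray)):\n",
    "        input_data = input_data.decode(\"utf-8\")\n",
    "    return json.loads(input_data)\n",
    "\n",
    "\n",
    "def output_fn(prediction, accept):\n",
    "    # Serialize the value returned by predict_fn into the response body.\n",
    "    if accept != \"application/json\":\n",
    "        raise ValueError(f\"Unsupported accept type: {accept}\")\n",
    "    return json.dumps(prediction)\n",
    "\n",
    "\n",
    "# Round-trip a sample payload locally, without a model in between.\n",
    "sample_body = json.dumps({\"inputs\": \"Hello world\"})\n",
    "sample_data = input_fn(sample_body, \"application/json\")\n",
    "print(sample_data)\n",
    "print(output_fn({\"vectors\": [0.1, 0.2]}, \"application/json\"))"
   ]
  },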
  {
   "cell_type": "markdown",
   "id": "e1e723ab",
   "metadata": {},
   "source": [
    "## Development Environment and Permissions\n",
    "\n",
    "### Installation \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69c59d90",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install sagemaker --upgrade"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ce0ef431",
   "metadata": {},
   "source": [
    "Install `git` and `git-lfs`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96d8dfea",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For notebook instances (Amazon Linux)\n",
    "!sudo yum update -y \n",
    "!curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.rpm.sh | sudo bash\n",
    "!sudo yum install git-lfs git -y\n",
    "# For other environments (Ubuntu)\n",
    "!sudo apt-get update -y \n",
    "!curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash\n",
    "!sudo apt-get install git-lfs git -y"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e4386d9",
   "metadata": {},
   "source": [
    "### Permissions\n",
    "\n",
    "_If you are going to use Sagemaker in a local environment (not SageMaker Studio or Notebook Instances). You need access to an IAM Role with the required permissions for Sagemaker. You can find [here](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html) more about it._"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "1c22e8d5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Couldn't call 'get_role' to get Role ARN from role name philippschmid to get Role path.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sagemaker role arn: arn:aws:iam::558105141721:role/sagemaker_execution_role\n",
      "sagemaker bucket: sagemaker-us-east-1-558105141721\n",
      "sagemaker session region: us-east-1\n"
     ]
    }
   ],
   "source": [
    "import sagemaker\n",
    "import boto3\n",
    "sess = sagemaker.Session()\n",
    "# sagemaker session bucket -> used for uploading data, models and logs\n",
    "# sagemaker will automatically create this bucket if it not exists\n",
    "sagemaker_session_bucket=None\n",
    "if sagemaker_session_bucket is None and sess is not None:\n",
    "    # set to default bucket if a bucket name is not given\n",
    "    sagemaker_session_bucket = sess.default_bucket()\n",
    "\n",
    "try:\n",
    "    role = sagemaker.get_execution_role()\n",
    "except ValueError:\n",
    "    iam = boto3.client('iam')\n",
    "    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']\n",
    "\n",
    "sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)\n",
    "\n",
    "print(f\"sagemaker role arn: {role}\")\n",
    "print(f\"sagemaker bucket: {sess.default_bucket()}\")\n",
    "print(f\"sagemaker session region: {sess.boto_region_name}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b0fc22f",
   "metadata": {},
   "source": [
    "## Create custom an `inference.py` script\n",
    "\n",
    "To use the custom inference script, you need to create an `inference.py` script. In our example, we are going to overwrite the `model_fn` to load our sentence transformer correctly and the `predict_fn` to apply mean pooling.\n",
    "\n",
    "We are going to use the [sentence-transformers/all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) model. It maps sentences & paragraphs to a 384 dimensional dense vector space and can be used for tasks like clustering or semantic search."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "b4246c06",
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "3ce41529",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting code/inference.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile code/inference.py\n",
    "\n",
    "from transformers import AutoTokenizer, AutoModel\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Helper: Mean Pooling - Take attention mask into account for correct averaging\n",
    "def mean_pooling(model_output, attention_mask):\n",
    "    token_embeddings = model_output[0] #First element of model_output contains all token embeddings\n",
    "    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()\n",
    "    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)\n",
    "\n",
    "\n",
    "def model_fn(model_dir):\n",
    "  # Load model from HuggingFace Hub\n",
    "  tokenizer = AutoTokenizer.from_pretrained(model_dir)\n",
    "  model = AutoModel.from_pretrained(model_dir)\n",
    "  return model, tokenizer\n",
    "\n",
    "def predict_fn(data, model_and_tokenizer):\n",
    "    # destruct model and tokenizer\n",
    "    model, tokenizer = model_and_tokenizer\n",
    "    \n",
    "    # Tokenize sentences\n",
    "    sentences = data.pop(\"inputs\", data)\n",
    "    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')\n",
    "\n",
    "    # Compute token embeddings\n",
    "    with torch.no_grad():\n",
    "        model_output = model(**encoded_input)\n",
    "\n",
    "    # Perform pooling\n",
    "    sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])\n",
    "\n",
    "    # Normalize embeddings\n",
    "    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)\n",
    "    \n",
    "    # return dictonary, which will be json serializable\n",
    "    return {\"vectors\": sentence_embeddings[0].tolist()}"
   ]
  },
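  {
   "cell_type": "markdown",
   "id": "d41b7e9a",
   "metadata": {},
   "source": [
    "To see why the attention mask matters in `mean_pooling`, the helper can be tried locally on a toy tensor. This is only a sketch: the values, and the names `toy_output` and `toy_mask`, are made up for illustration, and the helper is repeated here so the cell runs on its own."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e0c3b72",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "# Same pooling logic as in code/inference.py, repeated so this cell is self-contained.\n",
    "def mean_pooling(model_output, attention_mask):\n",
    "    token_embeddings = model_output[0]\n",
    "    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()\n",
    "    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)\n",
    "\n",
    "\n",
    "# Toy batch (made-up values): one sentence, three token positions, two dimensions.\n",
    "# The third position stands for padding and is masked out.\n",
    "toy_output = (torch.tensor([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]]),)\n",
    "toy_mask = torch.tensor([[1, 1, 0]])\n",
    "\n",
    "pooled = mean_pooling(toy_output, toy_mask)\n",
    "print(pooled)  # average of the two unmasked positions only"
   ]
  },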
  {
   "cell_type": "markdown",
   "id": "144d8ccb",
   "metadata": {},
   "source": [
    "## Create `model.tar.gz` with inference script and model \n",
    "\n",
    "To use our `inference.py` we need to bundle it into a `model.tar.gz` archive with all our model-artifcats, e.g. `pytorch_model.bin`. The `inference.py` script will be placed into a `code/` folder. We will use `git` and `git-lfs` to easily download our model from hf.co/models and upload it to Amazon S3 so we can use it when creating our SageMaker endpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "952983b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "repository = \"sentence-transformers/all-MiniLM-L6-v2\"\n",
    "model_id=repository.split(\"/\")[-1]\n",
    "s3_location=f\"s3://{sess.default_bucket()}/custom_inference/{model_id}/model.tar.gz\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "374ff630",
   "metadata": {},
   "source": [
    "1. Download the model from hf.co/models with `git clone`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "b8452981",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Updated git hooks.\n",
      "Git LFS initialized.\n",
      "Cloning into 'all-MiniLM-L6-v2'...\n",
      "remote: Enumerating objects: 25, done.\u001b[K\n",
      "remote: Counting objects: 100% (25/25), done.\u001b[K\n",
      "remote: Compressing objects: 100% (23/23), done.\u001b[K\n",
      "remote: Total 25 (delta 3), reused 0 (delta 0)\u001b[K.00 KiB/s\n",
      "Unpacking objects: 100% (25/25), 308.60 KiB | 454.00 KiB/s, done.\n"
     ]
    }
   ],
   "source": [
    "!git lfs install\n",
    "!git clone https://huggingface.co/$repository"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "09a6f330",
   "metadata": {},
   "source": [
    "2. copy `inference.py`  into the `code/` directory of the model directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "6146af09",
   "metadata": {},
   "outputs": [],
   "source": [
    "!cp -r code/ $model_id/code/"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "04e1395a",
   "metadata": {},
   "source": [
    "3. Create a `model.tar.gz` archive with all the model artifacts and the `inference.py` script.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "e65fd56e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/Users/philipp/.Trash/all-MiniLM-L6-v2/all-MiniLM-L6-v2\n",
      "a 1_Pooling\n",
      "a 1_Pooling/config.json\n",
      "a README.md\n",
      "a code\n",
      "a code/inference.py\n",
      "a config.json\n",
      "a config_sentence_transformers.json\n",
      "a data_config.json\n",
      "a modules.json\n",
      "a pytorch_model.bin\n",
      "a sentence_bert_config.json\n",
      "a special_tokens_map.json\n",
      "a tokenizer.json\n",
      "a tokenizer_config.json\n",
      "a train_script.py\n",
      "a vocab.txt\n"
     ]
    }
   ],
   "source": [
    "%cd $model_id\n",
    "!tar zcvf model.tar.gz *"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c858560",
   "metadata": {},
   "source": [
    "4. Upload the `model.tar.gz` to Amazon S3:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "c581bc40",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "upload: ./model.tar.gz to s3://sagemaker-us-east-1-558105141721/custom_inference/all-MiniLM-L6-v2/model.tar.gz\n"
     ]
    }
   ],
   "source": [
    "!aws s3 cp model.tar.gz $s3_location\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a146346",
   "metadata": {},
   "source": [
    "## Create custom `HuggingfaceModel` \n",
    "\n",
    "After we have created and uploaded our `model.tar.gz` archive to Amazon S3. Can we create a custom `HuggingfaceModel` class. This class will be used to create and deploy our SageMaker endpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "1c5ba990",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-----------!"
     ]
    }
   ],
   "source": [
    "from sagemaker.huggingface.model import HuggingFaceModel\n",
    "\n",
    "\n",
    "# create Hugging Face Model Class\n",
    "huggingface_model = HuggingFaceModel(\n",
    "   model_data=s3_location,       # path to your model and script\n",
    "   role=role,                    # iam role with permissions to create an Endpoint\n",
    "   transformers_version=\"4.26\",  # transformers version used\n",
    "   pytorch_version=\"1.13\",        # pytorch version used\n",
    "   py_version='py39',            # python version used\n",
    ")\n",
    "\n",
    "# deploy the endpoint endpoint\n",
    "predictor = huggingface_model.deploy(\n",
    "    initial_instance_count=1,\n",
    "    instance_type=\"ml.g4dn.xlarge\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b6b3812f",
   "metadata": {},
   "source": [
    "## Request Inference Endpoint using the `HuggingfacePredictor`\n",
    "\n",
    "The `.deploy()` returns an `HuggingFacePredictor` object which can be used to request inference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "51c5366b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'vectors': [0.005078191868960857, -0.0036594511475414038, 0.016988741233944893, -0.0015786211006343365, 0.030203675851225853, 0.09331899881362915, -0.0235157310962677, 0.011795195750892162, 0.03421774506568909, -0.027907833456993103, -0.03260169178247452, 0.0679800882935524, 0.015223750844597816, 0.025948498398065567, -0.07854384928941727, -0.0023915462661534548, 0.10089637339115143, 0.0014981384156271815, -0.017778029665350914, 0.005812637507915497, 0.02445339597761631, -0.0710371807217598, 0.04755859822034836, 0.026360979303717613, -0.05716250091791153, -0.0940014198422432, 0.047949012368917465, 0.008600219152867794, 0.03297032043337822, -0.06984368711709976, -0.0552142858505249, -0.03234352916479111, -0.0003443364112172276, 0.012479404918849468, -0.07419367134571075, 0.08545409888029099, 0.019597113132476807, 0.005851477384567261, -0.08256848156452179, 0.010150186717510223, 0.028275227174162865, -0.0016121627995744348, 0.04174523428082466, -0.009756717830896378, 0.03546829894185066, -0.0673336461186409, 0.013293622992932796, -0.047809384763240814, -0.02249010093510151, 0.028243854641914368, -0.08043544739484787, -0.01009676605463028, -0.03514788672327995, -0.021383730694651604, -0.002246067626401782, -0.015066167339682579, 0.04234122484922409, -0.040479838848114014, 0.00787312351167202, -0.04465996101498604, 0.010779906995594501, 0.0038497159257531166, -0.027719097211956978, -0.007967316545546055, 0.02942546270787716, -0.012327964417636395, 0.0050182887353003025, 0.06450540572404861, 0.03108026832342148, 0.042792391031980515, 0.023805316537618637, -0.01616135612130165, 0.02578461915254593, -0.08669176697731018, -0.044727668166160583, 7.097257184796035e-05, -0.10924965143203735, -0.10867254436016083, -0.03139006346464157, -0.03511088714003563, 0.08570166677236557, -0.134019672870636, -0.0005924605648033321, 0.029533952474594116, 0.012721308507025242, 0.02152288891375065, 0.0707324892282486, -0.11056605726480484, -0.1083742305636406, 0.0982309952378273, 
-0.039475709199905396, -0.05996376648545265, -0.10398901998996735, 0.03040657937526703, -0.03018292225897312, -0.03471128270030022, -0.06378458440303802, 0.016372960060834885, 0.0583597756922245, 0.012307470664381981, 0.04363206401467323, -0.031246762722730637, -0.09203378111124039, -0.0062785972841084, 0.015498220920562744, -0.07184164226055145, 0.012648160569369793, 0.014564670622348785, -0.08191244304180145, 0.023379981517791748, -0.011096887290477753, 0.0394676998257637, -0.033372823148965836, 0.041654154658317566, 0.0863155946135521, 0.015705395489931107, 0.01734650880098343, 0.08271384239196777, 0.022032614797353745, 0.03559378534555435, 0.12214990705251694, 0.032827410846948624, 0.026021108031272888, -0.019847815856337547, 0.010051277466118336, -0.04892867058515549, -0.0174998976290226, -1.4977462088666326e-33, -0.01998828910291195, -0.020090218633413315, 0.009214007295668125, 0.029388802126049995, 0.01617312990128994, 0.003455288475379348, -0.07258066534996033, 0.049684278666973114, -0.06154271960258484, 0.05080917105078697, 0.05352963134646416, -0.011941409669816494, -0.0028067785315215588, -0.041576843708753586, -0.010775507427752018, 0.00046661923988722265, 0.004454561043530703, 0.030003147199749947, -0.0516991950571537, -0.030697643756866455, -0.07532348483800888, 0.05465441197156906, -0.0385969914495945, -0.04381357878446579, -0.03235914930701256, 0.017494583502411842, 0.005240216851234436, 0.06198848783969879, -0.03355488181114197, 0.011264801025390625, -0.02115759812295437, 0.00838891975581646, -0.058978889137506485, -0.00011408641876187176, 0.05079993978142738, 0.015300493687391281, -0.07043343037366867, -0.07872467488050461, 0.09050456434488297, 0.03952907398343086, -0.07477521151304245, 0.03615942969918251, -0.058201417326927185, 0.0326484851539135, -0.03198658302426338, 0.11224830150604248, -0.016622459515929222, 0.0504615381360054, -0.04651995375752449, 0.1277347207069397, 0.03776664286851883, 0.05948572978377342, 0.09149560332298279, 
-0.009857898578047752, 0.004627745598554611, 0.03188807889819145, 0.062271688133478165, -0.0659433975815773, 0.0032127737067639828, -0.13898129761219025, 0.026403773576021194, 0.08804035186767578, -0.05001967027783394, 0.05326379835605621, -0.02196440100669861, 0.07656972110271454, 0.013867619447410107, -0.016544628888368607, -0.009327870793640614, 0.021883144974708557, -0.1560947597026825, -0.07534021139144897, -0.01896633207798004, 0.012034989893436432, -0.07331383228302002, -0.04332052916288376, -0.03353505954146385, 0.007872307673096657, 0.16191385686397552, -0.058967869728803635, 0.024201923981308937, 0.011731469072401524, -0.002475024200975895, -0.060298558324575424, -0.023722389712929726, -0.04882300645112991, 0.000707246595993638, -0.018090907484292984, 0.07239993661642075, 0.07933493703603745, 0.054174549877643585, -0.03342485427856445, -0.007864750921726227, 0.06494550406932831, -0.08771026879549026, 1.13459770849573e-33, 0.06040865182876587, 0.006845973432064056, -0.09519106149673462, -0.004926742985844612, 0.02894597128033638, -0.0077415574342012405, -0.05669841915369034, -0.034497782588005066, 0.09411472827196121, 0.0011957630049437284, -0.03672650456428528, 0.023257385939359665, -0.029259465634822845, -0.004881837405264378, -0.034621454775333405, -0.1123257502913475, 0.041878167539834976, 0.01935793086886406, 0.019774673506617546, 0.0033800536766648293, 0.04810955002903938, -0.043293364346027374, -0.019849350675940514, -0.024460462853312492, 0.011674574576318264, 0.028871286660432816, -0.04594291001558304, -0.009591681882739067, -0.020649896934628487, -0.0767439752817154, 0.06008455529808998, -0.07102784514427185, -0.03325150907039642, -0.07066744565963745, -0.07285013049840927, 0.06852841377258301, 0.032675426453351974, -0.015307767316699028, -0.03120141103863716, -0.0008060619584284723, -0.012935955077409744, 0.01687614619731903, 0.010606919415295124, 0.05316408351063728, -0.016209596768021584, 0.05059502646327019, -0.016619250178337097, 
-0.003106643445789814, -0.09400973469018936, 0.02362005040049553, -0.1493453085422516, 0.03363995999097824, -0.013002770021557808, -0.0411999374628067, -0.03762894868850708, 0.01735512912273407, -0.02544626034796238, -0.015723178163170815, 0.007998578250408173, 0.04340173304080963, 0.006307568401098251, -0.031614888459444046, -0.03868135064840317, -0.11168476939201355, 0.04688170179724693, 0.02938792295753956, 0.007106451783329248, -0.023254472762346268, 0.006188348866999149, 0.032097551971673965, 0.02284681424498558, -0.020912854000926018, -0.016115304082632065, 0.006232560612261295, -0.06727242469787598, 0.0027730280999094248, -0.04707656428217888, -0.03735049441456795, 0.026144297793507576, -0.013619091361761093, -0.005712081212550402, -0.04333459213376045, -0.008567489683628082, -0.0026371825952082872, -0.04714951291680336, 0.1506747603416443, 0.060538701713085175, 0.015910591930150986, 0.0021603393834084272, 0.09120813012123108, 0.10193410515785217, 0.04816991090774536, 0.07890739291906357, -0.05583663284778595, -0.02227107249200344, -2.478202887346015e-08, -0.08490563929080963, 0.04434036836028099, 0.02475418709218502, -0.024806825444102287, 0.00536795100197196, -0.06101489067077637, 0.014922979287803173, 0.04093354195356369, 0.03936637192964554, 0.04489367827773094, 0.012824231758713722, -0.03051156736910343, 0.0662570372223854, 0.04904399439692497, 0.004838698077946901, 0.07400422543287277, 0.03470872715115547, 0.037787146866321564, -0.043043263256549835, 0.04372495785355568, 0.023403732106089592, 0.057728372514247894, 0.034502316266298294, -0.049777042120695114, -0.0041667199693620205, 0.06382499635219574, -0.007370579522103071, -0.002130263252183795, -0.04700297489762306, 0.10623563826084137, -5.87037175137084e-05, -0.012606821022927761, 0.03633716702461243, 0.024944987148046494, -0.06500178575515747, 0.07670733332633972, 0.01752745360136032, 0.019638163968920708, 0.05920606851577759, 0.021030694246292114, 0.033589065074920654, 0.014452814124524593, 
0.030615368857979774, 0.13622330129146576, 0.0162414088845253, 0.07696809619665146, 0.10586545616388321, 0.06321518868207932, -0.06497083604335785, 0.0035124991554766893, 0.03836303576827049, -0.049263447523117065, -0.0939357802271843, 0.04310446232557297, 0.047002870589494705, 0.02352922037243843, 0.06475073844194412, 0.12606267631053925, -0.03936544433236122, 0.0033126939088106155, -0.005963532254099846, 0.01087606605142355, -0.006803632713854313, 0.05783495306968689]}\n"
     ]
    }
   ],
   "source": [
    "data = {\n",
    "  \"inputs\": \"the mesmerizing performances of the leads keep the film grounded and keep the audience riveted .\",\n",
    "}\n",
    "\n",
    "res = predictor.predict(data=data)\n",
    "print(res)\n"
   ]
  },
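  {
   "cell_type": "markdown",
   "id": "8b92f0c6",
   "metadata": {},
   "source": [
    "For tasks like semantic search, embeddings returned by the endpoint are typically compared with cosine similarity. The following is a minimal sketch in plain Python; `cosine_similarity`, `vec_a` and `vec_b` are hypothetical names, and the short vectors are made-up stand-ins for two `res[\"vectors\"]` lists. Because `predict_fn` L2-normalizes its output, the dot product alone would give the same value for real responses."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3a57d18",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "\n",
    "def cosine_similarity(a, b):\n",
    "    # Cosine similarity of two equal-length vectors given as lists of floats.\n",
    "    dot = sum(x * y for x, y in zip(a, b))\n",
    "    norm_a = math.sqrt(sum(x * x for x in a))\n",
    "    norm_b = math.sqrt(sum(y * y for y in b))\n",
    "    return dot / (norm_a * norm_b)\n",
    "\n",
    "\n",
    "# Made-up 3-dimensional vectors; real embeddings from this model have 384 dimensions.\n",
    "vec_a = [0.6, 0.8, 0.0]\n",
    "vec_b = [0.8, 0.6, 0.0]\n",
    "print(cosine_similarity(vec_a, vec_b))"
   ]
  },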
  {
   "cell_type": "markdown",
   "id": "cb10007d",
   "metadata": {},
   "source": [
    "### Delete model and endpoint\n",
    "\n",
    "To clean up, we can delete the model and endpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "1e6fb7b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "predictor.delete_model()\n",
    "predictor.delete_endpoint()"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "c281c456f1b8161c8906f4af2c08ed2c40c50136979eaae69688b01f70e9f4a9"
  },
  "kernelspec": {
   "display_name": "conda_pytorch_p39",
   "language": "python",
   "name": "conda_pytorch_p39"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
